package com.demo;

import com.alibaba.fastjson.JSON;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Media;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

@Slf4j
@RunWith(SpringRunner.class)
@SpringBootTest
public class OllamaTest {
    @Resource
    private OllamaChatModel ollamaChatModel;
    @Value("classpath:data/dog.png")
    private org.springframework.core.io.Resource imageResource;
    @Resource
    private TokenTextSplitter tokenTextSplitter;
    @Resource(name = "ollamaPgVectorStore")
    private PgVectorStore pgVectorStore;

    @Test
    public void test_model() {
        ChatOptions defaultOptions = ollamaChatModel.getDefaultOptions();
    }

    //普通对话
    @Test
    public void test_call() {
        ChatResponse response = ollamaChatModel.call(new Prompt(
                "1+1=?",
                OllamaOptions.builder().model("deepseek-r1:1.5b").build()));
        log.info("测试结果(call)：{}", JSON.toJSONString(response));
    }

    //gpu图像识别
    @Test
    public void test_call_images() {
        UserMessage userMessage = new UserMessage("请描述这张图片的主要内容，并说明图中物品的可能用途。",
                new Media(MimeType.valueOf(MimeTypeUtils.IMAGE_PNG_VALUE),
                        imageResource));
        ChatResponse response = ollamaChatModel.call(new Prompt(
                userMessage,
                OllamaOptions.builder().model("deepseek-r1:1.5b").build()));
        log.info("测试结果(call)：{}", JSON.toJSONString(response));
    }

    //流式对话
    @Test
    public void test_stream() throws InterruptedException {
        CountDownLatch countDownLatch = new CountDownLatch(1);
        Flux<ChatResponse> stream = ollamaChatModel.stream(new Prompt(
                "1+1=?",
                OllamaOptions.builder().model("deepseek-r1:1.5b").build()));
        //输出结果
        stream.subscribe(
                //consumer
                chatResponse -> {
                    AssistantMessage output = chatResponse.getResult().getOutput();
                    log.info("测试结果(stream)：{}", JSON.toJSONString(output));
                },
                //errorConsumer
                Throwable::printStackTrace,
                //complateConsumer
                ()->{
                    countDownLatch.countDown();
                    log.info("测试结果(stream): done!");
                }
        );
        countDownLatch.await();
    }

    //上传文件并解析知识库
    @Test
    public void upload() {
        TikaDocumentReader reader = new TikaDocumentReader("classpath:data/file1.txt");
        List<Document> documents = reader.get();
        List<Document> documentSplitList = tokenTextSplitter.apply(documents);
        documents.forEach(doc->doc.getMetadata().put("knowledge", "知识库名称v3"));
        documentSplitList.forEach(doc->doc.getMetadata().put("knowledge", "知识库名称v3"));
        pgVectorStore.accept(documentSplitList);
        log.info("上传完成");
    }

    //测试断乎rag是否接入了向量知识库
    @Test
    public void chatRag() {
        String msg = "张浩今年几岁";
        String SYSTEM_PROMPT = """
                Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
                If unsure, simply state that you don't know.
                Another thing you need to note is that your reply must be in Chinese!
                DOCUMENTS:
                    {documents}
                """;
        SearchRequest request = SearchRequest.builder()
                .query(msg)
                .topK(5)
                .filterExpression("knowledge == '知识库名称v3'")
                .build();
        List<Document> documents = pgVectorStore.similaritySearch(request);
        String documentsCollectors = documents.stream().map(Document::getText).collect(Collectors.joining());
        Message ragMessage = new SystemPromptTemplate(SYSTEM_PROMPT)
                .createMessage(Map.of("documents", documentsCollectors));
        List<Message> messages = new ArrayList<>();
        messages.add(ragMessage);
        ChatResponse response = ollamaChatModel.call(new Prompt(
                messages,
                OllamaOptions.builder().model("deepseek-r1:1.5b").build()
        ));
        log.info("测试结果: {}", JSON.toJSONString(response));
    }
}
